{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Experiment: \n",
    "\n",
    "Evaluate pruning by magnitude weighted by coactivations.\n",
    "\n",
    "#### Motivation.\n",
    "\n",
    "Test the newly proposed pruning method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import os\n",
    "import glob\n",
    "import tabulate\n",
    "import pprint\n",
    "import click\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from ray.tune.commands import *\n",
    "from dynamic_sparse.common.browser import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and check data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "exps = ['improved_magpruning_test1', ]\n",
    "paths = [os.path.expanduser(\"~/nta/results/{}\".format(e)) for e in exps]\n",
    "df = load_many(paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Experiment Name</th>\n",
       "      <th>train_acc_max</th>\n",
       "      <th>train_acc_max_epoch</th>\n",
       "      <th>train_acc_min</th>\n",
       "      <th>train_acc_min_epoch</th>\n",
       "      <th>train_acc_median</th>\n",
       "      <th>train_acc_last</th>\n",
       "      <th>val_acc_max</th>\n",
       "      <th>val_acc_max_epoch</th>\n",
       "      <th>val_acc_min</th>\n",
       "      <th>...</th>\n",
       "      <th>momentum</th>\n",
       "      <th>network</th>\n",
       "      <th>num_classes</th>\n",
       "      <th>on_perc</th>\n",
       "      <th>optim_alg</th>\n",
       "      <th>pruning_early_stop</th>\n",
       "      <th>test_noise</th>\n",
       "      <th>use_kwinners</th>\n",
       "      <th>weight_decay</th>\n",
       "      <th>weight_prune_perc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0_weight_prune_perc=None</td>\n",
       "      <td>0.988317</td>\n",
       "      <td>28</td>\n",
       "      <td>0.925067</td>\n",
       "      <td>0</td>\n",
       "      <td>0.985508</td>\n",
       "      <td>0.987333</td>\n",
       "      <td>0.9761</td>\n",
       "      <td>17</td>\n",
       "      <td>0.9628</td>\n",
       "      <td>...</td>\n",
       "      <td>0.9</td>\n",
       "      <td>MLPHeb</td>\n",
       "      <td>10</td>\n",
       "      <td>0.2</td>\n",
       "      <td>SGD</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1_weight_prune_perc=0.1</td>\n",
       "      <td>0.992317</td>\n",
       "      <td>29</td>\n",
       "      <td>0.926417</td>\n",
       "      <td>0</td>\n",
       "      <td>0.988500</td>\n",
       "      <td>0.992317</td>\n",
       "      <td>0.9803</td>\n",
       "      <td>25</td>\n",
       "      <td>0.9562</td>\n",
       "      <td>...</td>\n",
       "      <td>0.9</td>\n",
       "      <td>MLPHeb</td>\n",
       "      <td>10</td>\n",
       "      <td>0.2</td>\n",
       "      <td>SGD</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>0.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2_weight_prune_perc=0.2</td>\n",
       "      <td>0.992067</td>\n",
       "      <td>24</td>\n",
       "      <td>0.925400</td>\n",
       "      <td>0</td>\n",
       "      <td>0.988567</td>\n",
       "      <td>0.991500</td>\n",
       "      <td>0.9793</td>\n",
       "      <td>28</td>\n",
       "      <td>0.9598</td>\n",
       "      <td>...</td>\n",
       "      <td>0.9</td>\n",
       "      <td>MLPHeb</td>\n",
       "      <td>10</td>\n",
       "      <td>0.2</td>\n",
       "      <td>SGD</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>0.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3_weight_prune_perc=0.3</td>\n",
       "      <td>0.991183</td>\n",
       "      <td>29</td>\n",
       "      <td>0.927700</td>\n",
       "      <td>0</td>\n",
       "      <td>0.987950</td>\n",
       "      <td>0.991183</td>\n",
       "      <td>0.9814</td>\n",
       "      <td>24</td>\n",
       "      <td>0.9612</td>\n",
       "      <td>...</td>\n",
       "      <td>0.9</td>\n",
       "      <td>MLPHeb</td>\n",
       "      <td>10</td>\n",
       "      <td>0.2</td>\n",
       "      <td>SGD</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4_weight_prune_perc=0.4</td>\n",
       "      <td>0.990383</td>\n",
       "      <td>24</td>\n",
       "      <td>0.925767</td>\n",
       "      <td>0</td>\n",
       "      <td>0.986525</td>\n",
       "      <td>0.990117</td>\n",
       "      <td>0.9787</td>\n",
       "      <td>21</td>\n",
       "      <td>0.9632</td>\n",
       "      <td>...</td>\n",
       "      <td>0.9</td>\n",
       "      <td>MLPHeb</td>\n",
       "      <td>10</td>\n",
       "      <td>0.2</td>\n",
       "      <td>SGD</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>0.0001</td>\n",
       "      <td>0.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 41 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            Experiment Name  train_acc_max  train_acc_max_epoch  \\\n",
       "0  0_weight_prune_perc=None       0.988317                   28   \n",
       "1   1_weight_prune_perc=0.1       0.992317                   29   \n",
       "2   2_weight_prune_perc=0.2       0.992067                   24   \n",
       "3   3_weight_prune_perc=0.3       0.991183                   29   \n",
       "4   4_weight_prune_perc=0.4       0.990383                   24   \n",
       "\n",
       "   train_acc_min  train_acc_min_epoch  train_acc_median  train_acc_last  \\\n",
       "0       0.925067                    0          0.985508        0.987333   \n",
       "1       0.926417                    0          0.988500        0.992317   \n",
       "2       0.925400                    0          0.988567        0.991500   \n",
       "3       0.927700                    0          0.987950        0.991183   \n",
       "4       0.925767                    0          0.986525        0.990117   \n",
       "\n",
       "   val_acc_max  val_acc_max_epoch  val_acc_min  ...  momentum  network  \\\n",
       "0       0.9761                 17       0.9628  ...       0.9   MLPHeb   \n",
       "1       0.9803                 25       0.9562  ...       0.9   MLPHeb   \n",
       "2       0.9793                 28       0.9598  ...       0.9   MLPHeb   \n",
       "3       0.9814                 24       0.9612  ...       0.9   MLPHeb   \n",
       "4       0.9787                 21       0.9632  ...       0.9   MLPHeb   \n",
       "\n",
       "   num_classes  on_perc optim_alg  pruning_early_stop  test_noise  \\\n",
       "0           10      0.2       SGD                   0       False   \n",
       "1           10      0.2       SGD                   0       False   \n",
       "2           10      0.2       SGD                   0       False   \n",
       "3           10      0.2       SGD                   0       False   \n",
       "4           10      0.2       SGD                   0       False   \n",
       "\n",
       "   use_kwinners weight_decay weight_prune_perc  \n",
       "0         False       0.0001               NaN  \n",
       "1         False       0.0001               0.1  \n",
       "2         False       0.0001               0.2  \n",
       "3         False       0.0001               0.3  \n",
       "4         False       0.0001               0.4  \n",
       "\n",
       "[5 rows x 41 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trials configured with a pruning percentage of None are loaded as NaN;\n",
    "# treat them as 0.0 (no pruning) so they can be grouped and compared numerically.\n",
    "df['hebbian_prune_perc'] = df['hebbian_prune_perc'].fillna(0.0)\n",
    "df['weight_prune_perc'] = df['weight_prune_perc'].fillna(0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['Experiment Name', 'train_acc_max', 'train_acc_max_epoch',\n",
       "       'train_acc_min', 'train_acc_min_epoch', 'train_acc_median',\n",
       "       'train_acc_last', 'val_acc_max', 'val_acc_max_epoch', 'val_acc_min',\n",
       "       'val_acc_min_epoch', 'val_acc_median', 'val_acc_last', 'epochs',\n",
       "       'experiment_file_name', 'trial_time', 'mean_epoch_time', 'batch_norm',\n",
       "       'data_dir', 'dataset_name', 'debug_sparse', 'debug_weights', 'device',\n",
       "       'hebbian_prune_perc', 'hidden_sizes', 'input_size', 'learning_rate',\n",
       "       'lr_gamma', 'lr_milestones', 'lr_scheduler', 'model', 'momentum',\n",
       "       'network', 'num_classes', 'on_perc', 'optim_alg', 'pruning_early_stop',\n",
       "       'test_noise', 'use_kwinners', 'weight_decay', 'weight_prune_perc'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(24, 41)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Experiment Name                                   1_weight_prune_perc=0.1\n",
       "train_acc_max                                                    0.992317\n",
       "train_acc_max_epoch                                                    29\n",
       "train_acc_min                                                    0.926417\n",
       "train_acc_min_epoch                                                     0\n",
       "train_acc_median                                                   0.9885\n",
       "train_acc_last                                                   0.992317\n",
       "val_acc_max                                                        0.9803\n",
       "val_acc_max_epoch                                                      25\n",
       "val_acc_min                                                        0.9562\n",
       "val_acc_min_epoch                                                       0\n",
       "val_acc_median                                                     0.9769\n",
       "val_acc_last                                                       0.9764\n",
       "epochs                                                                 30\n",
       "experiment_file_name    /Users/lsouza/nta/results/improved_magpruning_...\n",
       "trial_time                                                        17.3449\n",
       "mean_epoch_time                                                  0.578165\n",
       "batch_norm                                                           True\n",
       "data_dir                                        /home/ubuntu/nta/datasets\n",
       "dataset_name                                                        MNIST\n",
       "debug_sparse                                                         True\n",
       "debug_weights                                                        True\n",
       "device                                                               cuda\n",
       "hebbian_prune_perc                                                      0\n",
       "hidden_sizes                                                          100\n",
       "input_size                                                            784\n",
       "learning_rate                                                         0.1\n",
       "lr_gamma                                                              0.1\n",
       "lr_milestones                                                          60\n",
       "lr_scheduler                                                  MultiStepLR\n",
       "model                                                     DSNNWeightedMag\n",
       "momentum                                                              0.9\n",
       "network                                                            MLPHeb\n",
       "num_classes                                                            10\n",
       "on_perc                                                               0.2\n",
       "optim_alg                                                             SGD\n",
       "pruning_early_stop                                                      0\n",
       "test_noise                                                          False\n",
       "use_kwinners                                                        False\n",
       "weight_decay                                                       0.0001\n",
       "weight_prune_perc                                                     0.1\n",
       "Name: 1, dtype: object"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.iloc[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "model\n",
       "DSNNWeightedMag    24\n",
       "Name: model, dtype: int64"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.groupby('model')['model'].count()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Experiment Details"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "base_exp_config = dict(\n",
    "    device=\"cuda\",\n",
    "    # dataset related\n",
    "    dataset_name=\"MNIST\",\n",
    "    data_dir=os.path.expanduser(\"~/nta/datasets\"),\n",
    "    input_size=784,\n",
    "    num_classes=10,\n",
    "    # network related\n",
    "    network=\"MLPHeb\",\n",
    "    hidden_sizes=[100, 100, 100],\n",
    "    batch_norm=True,\n",
    "    use_kwinners=False,\n",
    "    # model related\n",
    "    model=\"DSNNWeightedMag\",\n",
    "    on_perc=0.2,\n",
    "    optim_alg=\"SGD\",\n",
    "    momentum=0.9,\n",
    "    weight_decay=1e-4,    \n",
    "    learning_rate=0.1,\n",
    "    lr_scheduler=\"MultiStepLR\",\n",
    "    lr_milestones=[30,60,90],\n",
    "    lr_gamma=0.1,\n",
    "    # sparse related\n",
    "    hebbian_prune_perc=None,\n",
    "    weight_prune_perc=tune.grid_search([None, 0.1, 0.2, 0.3, 0.4, 0.5]),\n",
    "    pruning_early_stop=0,\n",
    "    # additional validation\n",
    "    test_noise=False,\n",
    "    # debugging\n",
    "    debug_weights=True,\n",
    "    debug_sparse=True,\n",
    ")\n",
    "\n",
    "# ray configurations\n",
    "tune_config = dict(\n",
    "    name=__file__.replace(\".py\", \"\") + \"_test1\",\n",
    "    num_samples=4,\n",
    "    local_dir=os.path.expanduser(\"~/nta/results\"),\n",
    "    checkpoint_freq=0,\n",
    "    checkpoint_at_end=False,\n",
    "    stop={\"training_iteration\": 30},\n",
    "    resources_per_trial={\"cpu\": 1, \"gpu\": 0.165},\n",
    "    loggers=DEFAULT_LOGGERS,\n",
    "    verbose=0,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Did any trials fail?\n",
    "df[df[\"epochs\"]<30][\"epochs\"].count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(24, 41)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Removing failed or incomplete trials\n",
    "df_origin = df.copy()\n",
    "df = df_origin[df_origin[\"epochs\"]>=30]\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Series([], Name: epochs, dtype: int64)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# which ones failed?\n",
    "# failed, or still ongoing?\n",
    "df_origin['failed'] = df_origin[\"epochs\"]<30\n",
    "df_origin[df_origin['failed']]['epochs']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper functions\n",
    "def mean_and_std(s):\n",
    "    \"\"\"Format `s` as 'mean ± std', each shown with three decimal places.\"\"\"\n",
    "    center, spread = s.mean(), s.std()\n",
    "    return \"{:.3f} ± {:.3f}\".format(center, spread)\n",
    "\n",
    "def round_mean(s):\n",
    "    \"\"\"Return the mean of `s`, rounded to the nearest integer, as a string.\"\"\"\n",
    "    nearest = round(s.mean())\n",
    "    return \"{:.0f}\".format(nearest)\n",
    "\n",
    "stats = ['min', 'max', 'mean', 'std']\n",
    "\n",
    "def agg(columns, filter=None, round=3):\n",
    "    \"\"\"Summarize validation accuracy of the trials in `df`, grouped by `columns`.\n",
    "\n",
    "    columns: column name(s) passed to `groupby`.\n",
    "    filter: optional row selector applied as `df[filter]`; all rows are used when None.\n",
    "    round: number of decimals passed to `.round()` on the aggregated frame.\n",
    "\n",
    "    Returns the grouped frame holding `round_mean` of `val_acc_max_epoch`, the\n",
    "    min/max/mean/std of `val_acc_max`, and the row count per group.\n",
    "    \"\"\"\n",
    "    # `filter` and `round` shadow the builtins inside this function only;\n",
    "    # the names are kept so existing keyword calls keep working.\n",
    "    trials = df if filter is None else df[filter]\n",
    "    summary_spec = {\n",
    "        'val_acc_max_epoch': round_mean,\n",
    "        'val_acc_max': stats,\n",
    "        'model': ['count'],\n",
    "    }\n",
    "    return trials.groupby(columns).agg(summary_spec).round(round)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### What are the optimal levels of Hebbian and weight pruning?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>val_acc_max_epoch</th>\n",
       "      <th colspan=\"4\" halign=\"left\">val_acc_max</th>\n",
       "      <th>model</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>round_mean</th>\n",
       "      <th>min</th>\n",
       "      <th>max</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>weight_prune_perc</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>20</td>\n",
       "      <td>0.976</td>\n",
       "      <td>0.977</td>\n",
       "      <td>0.977</td>\n",
       "      <td>0.000</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.1</th>\n",
       "      <td>22</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.982</td>\n",
       "      <td>0.980</td>\n",
       "      <td>0.001</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.2</th>\n",
       "      <td>27</td>\n",
       "      <td>0.978</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.001</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.3</th>\n",
       "      <td>24</td>\n",
       "      <td>0.980</td>\n",
       "      <td>0.981</td>\n",
       "      <td>0.980</td>\n",
       "      <td>0.001</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.4</th>\n",
       "      <td>24</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.981</td>\n",
       "      <td>0.980</td>\n",
       "      <td>0.001</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.5</th>\n",
       "      <td>23</td>\n",
       "      <td>0.978</td>\n",
       "      <td>0.980</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.001</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  val_acc_max_epoch val_acc_max                      model\n",
       "                         round_mean         min    max   mean    std count\n",
       "weight_prune_perc                                                         \n",
       "0.0                              20       0.976  0.977  0.977  0.000     4\n",
       "0.1                              22       0.979  0.982  0.980  0.001     4\n",
       "0.2                              27       0.978  0.979  0.979  0.001     4\n",
       "0.3                              24       0.980  0.981  0.980  0.001     4\n",
       "0.4                              24       0.979  0.981  0.980  0.001     4\n",
       "0.5                              23       0.978  0.980  0.979  0.001     4"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agg(['weight_prune_perc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr:last-of-type th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>val_acc_max_epoch</th>\n",
       "      <th colspan=\"4\" halign=\"left\">val_acc_max</th>\n",
       "      <th>model</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>round_mean</th>\n",
       "      <th>min</th>\n",
       "      <th>max</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>weight_prune_perc</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>20</td>\n",
       "      <td>0.976</td>\n",
       "      <td>0.977</td>\n",
       "      <td>0.977</td>\n",
       "      <td>0.000</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.2</th>\n",
       "      <td>27</td>\n",
       "      <td>0.978</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.001</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.4</th>\n",
       "      <td>24</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.981</td>\n",
       "      <td>0.980</td>\n",
       "      <td>0.001</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  val_acc_max_epoch val_acc_max                      model\n",
       "                         round_mean         min    max   mean    std count\n",
       "weight_prune_perc                                                         \n",
       "0.0                              20       0.976  0.977  0.977  0.000     4\n",
       "0.2                              27       0.978  0.979  0.979  0.001     4\n",
       "0.4                              24       0.979  0.981  0.980  0.001     4"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only pruning levels that are multiples of 0.2. An exact `% 0.2 == 0` test is\n",
    "# unreliable with floats (e.g. 0.6 % 0.2 is not exactly 0), so compare with a tolerance.\n",
    "prune_steps = df['weight_prune_perc'].astype(float) / 0.2\n",
    "multi2 = (prune_steps - prune_steps.round()).abs() < 1e-9\n",
    "agg(['weight_prune_perc'], multi2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* No relevant difference "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>weight_prune_perc</th>\n",
       "      <th>0.0</th>\n",
       "      <th>0.2</th>\n",
       "      <th>0.4</th>\n",
       "      <th>0.6</th>\n",
       "      <th>0.8</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hebbian_prune_perc</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0.0</th>\n",
       "      <td>0.976 ± 0.000</td>\n",
       "      <td>0.981 ± 0.001</td>\n",
       "      <td>0.980 ± 0.000</td>\n",
       "      <td>0.980 ± 0.001</td>\n",
       "      <td>0.979 ± 0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.2</th>\n",
       "      <td>0.974 ± 0.001</td>\n",
       "      <td>0.979 ± 0.001</td>\n",
       "      <td>0.979 ± 0.003</td>\n",
       "      <td>0.979 ± 0.002</td>\n",
       "      <td>0.977 ± 0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.4</th>\n",
       "      <td>0.971 ± 0.001</td>\n",
       "      <td>0.980 ± 0.001</td>\n",
       "      <td>0.979 ± 0.001</td>\n",
       "      <td>0.979 ± 0.001</td>\n",
       "      <td>0.978 ± 0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.6</th>\n",
       "      <td>0.969 ± 0.001</td>\n",
       "      <td>0.979 ± 0.001</td>\n",
       "      <td>0.980 ± 0.001</td>\n",
       "      <td>0.979 ± 0.001</td>\n",
       "      <td>0.978 ± 0.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0.8</th>\n",
       "      <td>0.966 ± 0.001</td>\n",
       "      <td>0.980 ± 0.001</td>\n",
       "      <td>0.980 ± 0.002</td>\n",
       "      <td>0.980 ± 0.001</td>\n",
       "      <td>0.978 ± 0.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1.0</th>\n",
       "      <td>0.964 ± 0.001</td>\n",
       "      <td>0.981 ± 0.001</td>\n",
       "      <td>0.981 ± 0.001</td>\n",
       "      <td>0.980 ± 0.001</td>\n",
       "      <td>0.980 ± 0.001</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "weight_prune_perc             0.0            0.2            0.4  \\\n",
       "hebbian_prune_perc                                                \n",
       "0.0                 0.976 ± 0.000  0.981 ± 0.001  0.980 ± 0.000   \n",
       "0.2                 0.974 ± 0.001  0.979 ± 0.001  0.979 ± 0.003   \n",
       "0.4                 0.971 ± 0.001  0.980 ± 0.001  0.979 ± 0.001   \n",
       "0.6                 0.969 ± 0.001  0.979 ± 0.001  0.980 ± 0.001   \n",
       "0.8                 0.966 ± 0.001  0.980 ± 0.001  0.980 ± 0.002   \n",
       "1.0                 0.964 ± 0.001  0.981 ± 0.001  0.981 ± 0.001   \n",
       "\n",
       "weight_prune_perc             0.6            0.8  \n",
       "hebbian_prune_perc                                \n",
       "0.0                 0.980 ± 0.001  0.979 ± 0.001  \n",
       "0.2                 0.979 ± 0.002  0.977 ± 0.001  \n",
       "0.4                 0.979 ± 0.001  0.978 ± 0.001  \n",
       "0.6                 0.979 ± 0.001  0.978 ± 0.001  \n",
       "0.8                 0.980 ± 0.001  0.978 ± 0.000  \n",
       "1.0                 0.980 ± 0.001  0.980 ± 0.001  "
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean ± std of the best validation accuracy for each (hebbian, weight) pruning pair.\n",
    "# The original `df[filter]` relied on a `filter` variable that is never defined in this\n",
    "# notebook, so on a fresh kernel it resolved to the Python builtin and raised.\n",
    "# NOTE(review): `multi2` is the only row mask defined in this notebook - confirm it\n",
    "# matches the selection originally intended here.\n",
    "pd.pivot_table(df[multi2],\n",
    "              index='hebbian_prune_perc',\n",
    "              columns='weight_prune_perc',\n",
    "              values='val_acc_max',\n",
    "              aggfunc=mean_and_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(108, 42)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Conclusions:\n",
    "- No pruning (0,0) leads to an accuracy of 0.976\n",
    "- Pruning all connections at every epoch (1,0) leads to an accuracy of 0.964\n",
    "- The best-performing model still uses no Hebbian pruning, with weight pruning set to 0.2 (accuracy 0.981)\n",
    "- Pruning only by Hebbian learning decreases accuracy\n",
    "- Combining Hebbian and weight-magnitude pruning is not an improvement over simple weight-magnitude pruning\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
